import os
import sys
from sklearn.preprocessing import StandardScaler

# Restrict CUDA to GPU 0; this runs before torch is imported further down.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Directory one level above the folder that contains this script.
PARENT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))

# Make the project root and its sub-packages importable. Entries are appended
# in this exact order, so existing sys.path entries keep their priority.
sys.path.extend(
    [PARENT_DIR]
    + [
        PARENT_DIR + sub_dir
        for sub_dir in ("/model", "/config", "/tools", "/model/net", "/model/trainer")
    ]
)

import numpy as np
import pandas as pd
import argparse
import torch
from mmengine import DictAction
from tools.load_config_file import load_config_from_file
torch.multiprocessing.set_sharing_strategy('file_system')
from model.environment.conformal_rl import CRL
from model.log.logger import logger
from model.log.tensorboard import tensorboard_logger
from model.log.wandb import wandb_logger
from model.trainer.rl.ppo_trainer import PPOTrainer
from model.agent.ppo_continuous import PPOContinuous


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse falls
            back to ``sys.argv[1:]``, so existing ``parse_args()`` calls
            behave exactly as before.

    Returns:
        argparse.Namespace: with the attributes ``config``, ``workdir``,
        ``tag``, ``log_path``, ``tensorboard_path``, ``checkpoint_path``,
        ``if_remove``, ``tensorboard``, ``writer``, ``wandb``, ``device`` and
        ``cfg_options``.
    """
    parser = argparse.ArgumentParser(description='Main')
    parser.add_argument("--config", default=os.path.join(PARENT_DIR, "config", "CORE.py"), help="config file path")

    # Output locations; the path options default to None.
    parser.add_argument("--workdir", type=str, default="workdir")
    parser.add_argument("--tag", type=str, default=None)
    parser.add_argument("--log_path", type=str, default=None)
    parser.add_argument("--tensorboard_path", type=str, default=None)
    parser.add_argument("--checkpoint_path", type=str, default=None)
    parser.add_argument("--if_remove", action="store_true", default=False)

    # TensorBoard is on by default; only --no_tensorboard changes the value.
    parser.add_argument("--tensorboard", action="store_true", default=True, help="enable tensorboard")
    parser.add_argument("--no_tensorboard", action="store_false", dest="tensorboard")
    # Adds a `writer` attribute that is always True; no flag in this parser
    # modifies it. NOTE(review): possibly a leftover alias for `tensorboard`
    # -- confirm with callers before removing.
    parser.set_defaults(writer=True)

    # Weights & Biases is on by default; only --no_wandb changes the value.
    parser.add_argument("--wandb", action="store_true", default=True, help="enable wandb")
    parser.add_argument("--no_wandb", action="store_false", dest="wandb")
    parser.set_defaults(wandb=True)

    parser.add_argument("--device", default="cuda", help="device to use for training / testing")

    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # Passing argv=None makes argparse read sys.argv[1:].
    args = parser.parse_args(argv)
    return args


def _normalize_with_previous_patch_stats(data_values, patch_size, eps=1e-6):
    """Z-score consecutive row patches using the previous patch's statistics.

    The rows of ``data_values`` are split into consecutive patches of
    ``patch_size`` rows (the last patch may be shorter). The first patch is
    normalised with its own per-column mean/std; every later patch is
    normalised with the mean/std of the patch immediately before it.

    Args:
        data_values: 2-D numpy array; statistics are taken along axis 0.
        patch_size: Positive integer number of rows per patch.
        eps: Constant added to every standard deviation to avoid division
            by zero.

    Returns:
        numpy.ndarray of the same shape as ``data_values`` with a floating
        dtype (integer inputs are promoted so results are not truncated).

    Raises:
        ValueError: If ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size!r}")

    total_steps = data_values.shape[0]
    # zeros_like() would inherit an integer dtype from all-integer data and
    # silently truncate the normalised values; promote to a float dtype.
    out_dtype = np.result_type(data_values.dtype, np.float32)
    normalized = np.zeros(data_values.shape, dtype=out_dtype)

    ref_mean = ref_std = None
    for start in range(0, total_steps, patch_size):
        stop = start + patch_size  # slicing clips this to total_steps
        patch = data_values[start:stop]
        patch_mean = patch.mean(axis=0)
        patch_std = patch.std(axis=0) + eps

        if ref_mean is None:
            # First patch: no earlier statistics exist, so use its own.
            ref_mean, ref_std = patch_mean, patch_std

        normalized[start:stop] = (patch - ref_mean) / ref_std

        # The statistics of this patch normalise the next one.
        ref_mean, ref_std = patch_mean, patch_std

    return normalized


def run_crl_experiment():
    """Train the PPO agent on the dataset described by the ``CORE`` config.

    Steps performed:
      1. Load ``config/CORE.py`` and pick CUDA when available, else CPU.
      2. Create the work and checkpoint directories and initialise the file
         logger plus, when enabled in the config, TensorBoard and wandb.
      3. Read ``<file_path><dataset>/<dataset>.csv``, drop the first column
         and normalise the remaining values patch by patch.
      4. Split the normalised rows into train/validation parts by
         ``train_ratio`` and wrap each part in a ``CRL`` environment.
      5. Build a ``PPOContinuous`` agent and run ``PPOTrainer.train()``.

    Raises:
        ValueError: If ``config['patch_size']`` is not positive.
    """
    model = "CORE"
    config = load_config_from_file(file_path=PARENT_DIR + "/config/" + model + ".py")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("model: " + model)
    print("dataset: " + config['dataset'])

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(config['work_dir'], exist_ok=True)
    os.makedirs(config['checkpoint_path'], exist_ok=True)

    logger.init_logger(config['log_path'], accelerator=None)

    if config['tensorboard']:
        tensorboard_logger.init_logger(config['tensorboard_path'], accelerator=None)
    if config['wandb']:
        wandb_logger.init_logger(
            project=config['project'],
            name=config['tag'],
            config=config,
            dir=config['wandb_path'],
            accelerator=None,
        )

    # Plain concatenation: config['file_path'] must already end with a path
    # separator for this to resolve to <file_path><dataset>/<dataset>.csv.
    df_raw = pd.read_csv(config['file_path'] + f"{config['dataset']}/" + config['dataset'] + ".csv")

    # The first CSV column is skipped -- presumably a timestamp/index column;
    # verify against the dataset.
    cols_data = df_raw.columns[1:]
    data_values = df_raw[cols_data].to_numpy()
    num_train = int(len(data_values) * config['train_ratio'])

    normalized_data = _normalize_with_previous_patch_stats(data_values, config['patch_size'])
    train_data_numpy = normalized_data[:num_train]
    val_data_numpy = normalized_data[num_train:]

    env_train = CRL(train_data_numpy, window_size=config['window_size'], alpha=config['alpha'],
                    calibration_size=config['calibration_size'],
                    score_function="CQR", gamma=config['gamma'], device=device)
    env_val = CRL(val_data_numpy, window_size=config['window_size'], alpha=config['alpha'],
                  calibration_size=config['calibration_size'],
                  score_function="CQR", gamma=config['gamma'], device=device)

    agent = PPOContinuous(env_train, quantile_num=len(config['quantile']),
                          method=config['method'], middle_dim=config['middle_dim'], embed_dim=config['embed_dim'],
                          depth=config['depth'], group_size=config['group_size'], num_head=config['num_head'],
                          window_size=config['window_size'], confidence=config['action_conf'], device=device)
    # The trainer receives the raw (un-normalised) data_values next to the
    # environments built from the normalised data.
    trainer = PPOTrainer(config=config, agent=agent,
                         data_values=data_values,
                         train_environment=env_train,
                         valid_environment=env_val,
                         test_environment=None,
                         logger=logger,
                         writer=tensorboard_logger,
                         wandb=wandb_logger,
                         device=device)
    trainer.train()




# Script entry point: start the experiment only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    # NOTE(review): parse_args() is not called in this file, so command-line
    # flags currently have no effect on the run.
    run_crl_experiment()